{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from graphrag.config.enums import ModelType\n",
    "from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
    "from graphrag.config.models.language_model_config import LanguageModelConfig\n",
    "from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
    "from graphrag.language_model.manager import ModelManager\n",
    "from graphrag.query.indexer_adapters import (\n",
    "    read_indexer_entities,\n",
    "    read_indexer_relationships,\n",
    "    read_indexer_report_embeddings,\n",
    "    read_indexer_reports,\n",
    "    read_indexer_text_units,\n",
    ")\n",
    "from graphrag.query.structured_search.drift_search.drift_context import (\n",
    "    DRIFTSearchContextBuilder,\n",
    ")\n",
    "from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
    "from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
    "from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
    "\n",
    "INPUT_DIR = \"./inputs/operation dulce\"\n",
    "LANCEDB_URI = f\"{INPUT_DIR}/lancedb\"\n",
    "\n",
    "COMMUNITY_REPORT_TABLE = \"community_reports\"\n",
    "COMMUNITY_TABLE = \"communities\"\n",
    "ENTITY_TABLE = \"entities\"\n",
    "RELATIONSHIP_TABLE = \"relationships\"\n",
    "COVARIATE_TABLE = \"covariates\"\n",
    "TEXT_UNIT_TABLE = \"text_units\"\n",
    "COMMUNITY_LEVEL = 2\n",
    "\n",
    "\n",
    "# read the entity and community tables; both are passed to read_indexer_entities\n",
    "entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
    "community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
    "\n",
    "print(f\"Entity df columns: {entity_df.columns}\")\n",
    "\n",
    "entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)\n",
    "\n",
    "# connect to the LanceDB database at LANCEDB_URI (a directory under INPUT_DIR);\n",
    "# this store is bound to the entity description index\n",
    "description_embedding_store = LanceDBVectorStore(\n",
    "    vector_store_schema_config=VectorStoreSchemaConfig(\n",
    "        index_name=\"default-entity-description\"\n",
    "    ),\n",
    ")\n",
    "description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
    "\n",
    "# second store on the same database, bound to the community full-content index\n",
    "full_content_embedding_store = LanceDBVectorStore(\n",
    "    vector_store_schema_config=VectorStoreSchemaConfig(\n",
    "        index_name=\"default-community-full_content\"\n",
    "    )\n",
    ")\n",
    "full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
    "\n",
    "print(f\"Entity count: {len(entity_df)}\")\n",
    "# a cell only renders its last expression, so intermediate previews\n",
    "# go through display() to actually show up in the output\n",
    "display(entity_df.head())\n",
    "\n",
    "relationship_df = pd.read_parquet(f\"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet\")\n",
    "relationships = read_indexer_relationships(relationship_df)\n",
    "\n",
    "print(f\"Relationship count: {len(relationship_df)}\")\n",
    "display(relationship_df.head())\n",
    "\n",
    "text_unit_df = pd.read_parquet(f\"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet\")\n",
    "text_units = read_indexer_text_units(text_unit_df)\n",
    "\n",
    "print(f\"Text unit records: {len(text_unit_df)}\")\n",
    "# last expression of the cell: rendered without an explicit display() call\n",
    "text_unit_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
    "\n",
    "# keyword arguments that the chat and embedding configurations have in common\n",
    "shared_model_settings = {\n",
    "    \"api_key\": api_key,\n",
    "    \"model_provider\": \"openai\",\n",
    "    \"max_retries\": 20,\n",
    "}\n",
    "\n",
    "# chat model\n",
    "chat_config = LanguageModelConfig(\n",
    "    type=ModelType.Chat,\n",
    "    model=\"gpt-4.1\",\n",
    "    **shared_model_settings,\n",
    ")\n",
    "chat_model = ModelManager().get_or_create_chat_model(\n",
    "    name=\"local_search\",\n",
    "    model_type=ModelType.Chat,\n",
    "    config=chat_config,\n",
    ")\n",
    "\n",
    "# the tokenizer is obtained from the chat model configuration\n",
    "tokenizer = get_tokenizer(chat_config)\n",
    "\n",
    "# embedding model\n",
    "embedding_config = LanguageModelConfig(\n",
    "    type=ModelType.Embedding,\n",
    "    model=\"text-embedding-3-small\",\n",
    "    **shared_model_settings,\n",
    ")\n",
    "\n",
    "text_embedder = ModelManager().get_or_create_embedding_model(\n",
    "    name=\"local_search_embedding\",\n",
    "    model_type=ModelType.Embedding,\n",
    "    config=embedding_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_community_reports(\n",
    "    input_dir: str | Path,\n",
    "    community_report_table: str = COMMUNITY_REPORT_TABLE,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"Read the community reports table from a parquet file into a DataFrame.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_dir : str | Path\n",
    "        Directory that contains the parquet file.\n",
    "    community_report_table : str\n",
    "        Table name, i.e. the parquet file name without its extension.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Contents of ``<input_dir>/<community_report_table>.parquet``.\n",
    "    \"\"\"\n",
    "    input_path = Path(input_dir) / f\"{community_report_table}.parquet\"\n",
    "    return pd.read_parquet(input_path)\n",
    "\n",
    "\n",
    "report_df = read_community_reports(INPUT_DIR)\n",
    "reports = read_indexer_reports(\n",
    "    report_df,\n",
    "    community_df,\n",
    "    COMMUNITY_LEVEL,\n",
    "    content_embedding_col=\"full_content_embeddings\",\n",
    ")\n",
    "# NOTE(review): presumably fills in the reports' full-content embeddings from the\n",
    "# vector store connected earlier; confirm against graphrag's indexer adapters\n",
    "read_indexer_report_embeddings(reports, full_content_embedding_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings handed to the DRIFT context builder\n",
    "drift_params = DRIFTSearchConfig(\n",
    "    temperature=0,\n",
    "    max_tokens=12_000,\n",
    "    n=1,\n",
    "    primer_folds=1,\n",
    "    drift_k_followups=3,\n",
    "    n_depth=3,\n",
    ")\n",
    "\n",
    "# the context builder receives the models plus every data object loaded above\n",
    "context_builder = DRIFTSearchContextBuilder(\n",
    "    config=drift_params,\n",
    "    model=chat_model,\n",
    "    text_embedder=text_embedder,\n",
    "    tokenizer=tokenizer,\n",
    "    entities=entities,\n",
    "    entity_text_embeddings=description_embedding_store,\n",
    "    relationships=relationships,\n",
    "    reports=reports,\n",
    "    text_units=text_units,\n",
    ")\n",
    "\n",
    "# search engine used by the query cell below\n",
    "search = DRIFTSearch(\n",
    "    model=chat_model,\n",
    "    context_builder=context_builder,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# question put to the DRIFT search engine\n",
    "question = \"Who is agent Mercer?\"\n",
    "resp = await search.search(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# response attribute of the search result; rendered as the cell's last expression\n",
    "resp.response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# context data attached to the search result, printed as plain text\n",
    "print(resp.context_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphrag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
